package org.joone.engine.extenders;

import org.joone.engine.Matrix;

/* loaded from: input_file:org/joone/engine/extenders/BatchModeExtender.class */
/**
 * Weight-update extender that accumulates the individual weight/bias changes
 * in a private working matrix and only writes them back to the layer or
 * synapse once enough updates have been counted (see
 * {@link #storeWeightsBiases()}).
 *
 * <p>The pre/post hooks are presumably invoked by the owning learner around
 * each update cycle -- TODO confirm against the learner implementation.
 */
public class BatchModeExtender extends UpdateWeightExtender {
    /** Working copy whose {@code delta} array collects the pending changes. */
    private Matrix theMatrix;

    /** Explicit batch size; a negative value defers to the monitor's batch size. */
    private int theBatchSize = -1;

    /** Row count recorded when the current working matrix was created. */
    private int theRows = -1;

    /** Column count recorded when the current working matrix was created (synapse case only). */
    private int theColumns = -1;

    /** Number of pre-update calls seen since the last flush or batch start. */
    private int theCounter = 0;

    /**
     * Flushes the accumulated bias deltas once the batch is complete: adds
     * every delta to its value, hands a clone of the working matrix to the
     * layer, then clears the deltas and the counter.
     *
     * @param dArr not used by this implementation
     */
    @Override
    public void postBiasUpdate(double[] dArr) {
        if (!storeWeightsBiases()) {
            return; // batch not complete yet, keep accumulating
        }
        for (int row = 0; row < this.theRows; row++) {
            // biases live in column 0 of the matrix
            this.theMatrix.value[row][0] += this.theMatrix.delta[row][0];
        }
        getLearner().getLayer().setBias((Matrix) this.theMatrix.clone());
        resetDelta(this.theMatrix);
        this.theCounter = 0;
    }

    /**
     * Flushes the accumulated weight deltas once the batch is complete: adds
     * every delta to its value, hands a clone of the working matrix to the
     * synapse, then clears the deltas and the counter.
     *
     * @param dArr  not used by this implementation
     * @param dArr2 not used by this implementation
     */
    @Override
    public void postWeightUpdate(double[] dArr, double[] dArr2) {
        if (!storeWeightsBiases()) {
            return; // batch not complete yet, keep accumulating
        }
        for (int row = 0; row < this.theRows; row++) {
            for (int col = 0; col < this.theColumns; col++) {
                this.theMatrix.value[row][col] += this.theMatrix.delta[row][col];
            }
        }
        getLearner().getSynapse().setWeights((Matrix) this.theMatrix.clone());
        resetDelta(this.theMatrix);
        this.theCounter = 0;
    }

    /**
     * Counts one more bias update, starting a fresh batch first when the
     * layer's row count no longer matches the recorded one.
     *
     * @param dArr not used by this implementation
     */
    @Override
    public void preBiasUpdate(double[] dArr) {
        int layerRows = getLearner().getLayer().getRows();
        if (layerRows != this.theRows) {
            initiateNewBatch();
        }
        this.theCounter++;
    }

    /**
     * Counts one more weight update, starting a fresh batch first when the
     * synapse's input/output dimensions no longer match the recorded ones.
     *
     * @param dArr  not used by this implementation
     * @param dArr2 not used by this implementation
     */
    @Override
    public void preWeightUpdate(double[] dArr, double[] dArr2) {
        // && short-circuits: the output dimension is only queried when the rows match
        boolean sameShape = this.theRows == getLearner().getSynapse().getInputDimension()
                && this.theColumns == getLearner().getSynapse().getOutputDimension();
        if (!sameShape) {
            initiateNewBatch();
        }
        this.theCounter++;
    }

    /**
     * Adds a bias change to the pending delta of the given row.
     *
     * @param i row index of the bias
     * @param d amount to add to the pending delta
     */
    @Override
    public void updateBias(int i, double d) {
        this.theMatrix.delta[i][0] += d;
    }

    /**
     * Adds a weight change to the pending delta of the given cell.
     *
     * @param i  row index of the weight
     * @param i2 column index of the weight
     * @param d  amount to add to the pending delta
     */
    @Override
    public void updateWeight(int i, int i2, double d) {
        this.theMatrix.delta[i][i2] += d;
    }

    /**
     * Sets entries of the matrix's {@code delta} array to zero. The width of
     * the first row is used as the column bound for every row.
     *
     * @param matrix matrix whose deltas are cleared
     */
    protected void resetDelta(Matrix matrix) {
        for (double[] deltaRow : matrix.delta) {
            for (int col = 0; col < matrix.delta[0].length; col++) {
                deltaRow[col] = 0.0d;
            }
        }
    }

    /**
     * Starts a new batch: records the current dimensions, replaces the
     * working matrix with a clone of the layer's biases (when the learner has
     * a layer) or otherwise of the synapse's weights (when it has a synapse),
     * and clears the deltas and the counter.
     */
    protected void initiateNewBatch() {
        if (getLearner().getLayer() != null) {
            this.theRows = getLearner().getLayer().getRows();
            this.theMatrix = (Matrix) getLearner().getLayer().getBias().clone();
        } else if (getLearner().getSynapse() != null) {
            this.theRows = getLearner().getSynapse().getInputDimension();
            this.theColumns = getLearner().getSynapse().getOutputDimension();
            this.theMatrix = (Matrix) getLearner().getSynapse().getWeights().clone();
        }
        // NOTE(review): with neither layer nor synapse the previous matrix (possibly null) is reset here
        resetDelta(this.theMatrix);
        this.theCounter = 0;
    }

    /**
     * Sets an explicit batch size for this extender.
     *
     * @param i the batch size; a negative value makes {@link #getBatchSize()}
     *          fall back to the monitor's batch size
     */
    public void setBatchSize(int i) {
        this.theBatchSize = i;
    }

    /**
     * Returns the batch size in effect.
     *
     * @return the explicitly set batch size, or the monitor's batch size when
     *         the stored value is negative
     */
    public int getBatchSize() {
        if (this.theBatchSize < 0) {
            return getLearner().getMonitor().getBatchSize();
        }
        return this.theBatchSize;
    }

    /**
     * Tells whether the accumulated changes should be written back now.
     *
     * @return {@code true} once the update counter has reached the batch size
     */
    @Override
    public boolean storeWeightsBiases() {
        int seen = this.theCounter;
        return seen >= getBatchSize();
    }
}
